'''
pkl2mot.py

Convert a joint-tracking pickle (.pkl) recorded in PyBullet to an OpenSim motion (.mot) file.

Compatible with the joint-name equivalence used by the original pickles in the NeuroMechFly repo (https://github.com/NeLy-EPFL/NeuroMechFly/tree/main/data/joint_tracking)

input:
The .pkl file to convert 

output:
The name of the .mot file to export 

usage:

The .pkl files are stored in data/joint_tracking:
python pkl2mot.py -i grooming_nmf.pkl -o grooming_opensim.mot
python pkl2mot.py -i walking_nmf.pkl -o walking_opensim.mot
'''

import numpy as np
import pandas as pd
import argparse

# Translation from the joint names used in the original NeuroMechFly
# joint-tracking pickles to the segment/axis suffix of the model's joint names.
names_equivalence = {
    'ThC_pitch': 'Coxa_pitch',
    'ThC_yaw': 'Coxa_yaw',
    'ThC_roll': 'Coxa_roll',
    'CTr_pitch': 'Femur_pitch',
    'CTr_roll': 'Femur_roll',
    'FTi_pitch': 'Tibia_pitch',
    'TiTa_pitch': 'Tarsus1_pitch',
}

# Interval between consecutive samples, written to the 'time' column.
# Presumably seconds (OpenSim convention) -- confirm against the recording rate.
DEFAULT_STEP_TIME = 5e-4


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input``, ``output``, ``start`` and
        ``step_time`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Translate joint tracking pickle to osim motion')
    parser.add_argument('-i', '--input', required=True,
                        help='joint tracking pickle (.pkl) to convert')
    parser.add_argument('-o', '--output', required=True,
                        help='OpenSim motion file (.mot) to write')
    parser.add_argument('--start', type=int, default=0,
                        help='index of the first sample to export (default: 0)')
    parser.add_argument('--step-time', type=float, default=DEFAULT_STEP_TIME,
                        help='time between samples (default: %(default)s)')
    return parser.parse_args(argv)


def convert_joint_names(data, start=0):
    """Flatten per-leg joint series into model joint names.

    Args:
        data: mapping ``{leg: {joint_name: series}}``. The first two
            characters of each leg key form the leg prefix, and every
            ``joint_name`` must be a key of ``names_equivalence``.
        start: index of the first sample kept from every series.

    Returns:
        dict mapping ``'joint_' + leg[:2] + names_equivalence[joint_name]``
        to ``series[start:]``, in the input's iteration order.

    Raises:
        KeyError: if a joint name has no entry in ``names_equivalence``.
    """
    converted = {}
    for leg, joints in data.items():
        prefix = 'joint_' + leg[:2]
        for joint_name, series in joints.items():
            try:
                suffix = names_equivalence[joint_name]
            except KeyError:
                raise KeyError(
                    'unknown joint {!r} in leg {!r}; expected one of {}'.format(
                        joint_name, leg, sorted(names_equivalence))) from None
            converted[prefix + suffix] = series[start:]
    return converted


def write_mot(filename, columns, step_time=DEFAULT_STEP_TIME):
    """Write joint angle series to an OpenSim motion (.mot) file.

    Args:
        filename: path of the .mot file to create (overwritten if present).
        columns: mapping ``{column_name: series}``. Series are treated as
            radians and written in degrees; all must have the same length.
        step_time: interval between consecutive rows of the ``time`` column.

    Returns:
        Number of data rows written.

    Raises:
        ValueError: if ``columns`` is empty or the series differ in length.
    """
    if not columns:
        raise ValueError('no joint angle series to write')
    lengths = {len(series) for series in columns.values()}
    if len(lengths) != 1:
        # Every row needs one value per column; unequal lengths would either
        # drop data or index past the end of a shorter series.
        raise ValueError(
            'joint angle series differ in length: {}'.format(sorted(lengths)))
    (steps,) = lengths
    names = list(columns)
    # Convert each series to degrees once instead of cell by cell.
    degrees = [np.asarray(columns[name], dtype=float) * 180 / np.pi
               for name in names]

    with open(filename, 'w') as fo:
        fo.write("Coordinates\n")
        fo.write("version=1\n")
        fo.write("nRows={}\n".format(steps))
        fo.write("nColumns={}\n".format(len(names) + 1))  # +1 for 'time'
        fo.write("inDegrees=yes\n")
        fo.write("endheader\n")
        fo.write("".join("{:>16}\t".format(label) for label in ['time'] + names))
        fo.write("\n")
        for i in range(steps):
            row = [step_time * i] + [series[i] for series in degrees]
            fo.write("".join("{:16.8f}\t".format(value) for value in row))
            fo.write("\n")
    return steps


def main(argv=None):
    """Convert the pickle named on the command line into a .mot file."""
    args = parse_args(argv)
    # NOTE: unpickling can execute arbitrary code -- only convert trusted files.
    data = pd.read_pickle(args.input)
    columns = convert_joint_names(data, start=args.start)
    steps = write_mot(args.output, columns, step_time=args.step_time)
    print('Wrote {} rows x {} joints to {}'.format(
        steps, len(columns), args.output))


if __name__ == '__main__':
    main()
